//-*-C++-*-
/***************************************************************************
 *
 *   Copyright (C) 2006-2025 by Willem van Straten and Will Gauvin
 *   Licensed under the Academic Free License version 2.1
 *
 ***************************************************************************/

// dspsr/Signal/General/dsp/Rescale.h

#ifndef __dsp_Rescale_h
#define __dsp_Rescale_h

#include "dsp/Transformation.h"
#include "dsp/TimeSeries.h"

#include <vector>

namespace dsp
{
  /**
   * @brief Rescale all channels and polarisations independently to a zero mean and unit variance.
   *
   * The work can be delegated to an alternate Engine (see set_engine), and the
   * scales and offsets are computed by a ScaleOffsetCalculator (see set_calculator).
   */
  class Rescale : public Transformation<TimeSeries,TimeSeries>
  {

  public:

    //! Default constructor
    Rescale ();

    //! Destructor
    ~Rescale () = default;

    //! Prepare the transformation before data are processed
    // NOTE(review): presumably overrides a base-class prepare method; confirm against dsp/Transformation.h
    void prepare ();

    //! Rescale to zero mean and unit variance
    void transformation ();

    /**
     * @brief set the rescaling interval in seconds.
     *
     * The number of samples used in the rescaling will be (seconds * input->get_rate())
     *
     * @param seconds the number of seconds to be used in rescaling calculations.
     */
    void set_interval_seconds (double seconds);

    /**
     * @brief set the number of samples used during rescaling calculations.
     *
     * @param samples the number of samples to use for calculating and applying scales and offsets.
     */
    void set_interval_samples (uint64_t samples);

    /**
     * @brief set indicator of whether exactly the number of samples set in set_interval_samples should be used.
     *
     * If exact is true then InputBuffering is used to ensure that at least nsamples are available before
     * the transformation is applied to input data.
     *
     * @param exact indicator of whether the exact number of sample intervals are used or not.
     * @throws Error if exact == true and set_interval_samples has not been called.
     */
    void set_exact (bool exact);

    /**
     * @brief set whether to use a constant offset and scale after it has been calculated the first time.
     *
     * @param constant_offset_scale whether to use a constant offset and scale after it has been calculated the first time.
     */
    void set_constant (bool constant_offset_scale);

    /**
     * @brief set whether to subtract an exponential smooth with specified decay constant.
     *
     * @param decay_constant the decay constant to use for exponential smoothing.
     */
    void set_decay (float decay_constant);

    //! Do not output any data before the first integration interval has passed
    void set_output_after_interval (bool output_after_interval);

    //! Maintain fscrunched total that can be output
    void set_output_time_total (bool output_time_total);

    //! Get the epoch of the last scale/offset update
    MJD get_update_epoch () const;

    //! Get the spectrum of offsets for the given polarization
    const float* get_offset (unsigned ipol) const;

    //! Get the spectrum of scales for the given polarization
    const float* get_scale (unsigned ipol) const;

    /**
     * @brief Get the spectrum of integrated input values for the given polarization
     *
     * Integration is performed across time samples.
     * For detected inputs this spectrum is the integrated bandpass.
     */
    const double* get_mean (unsigned ipol) const;

    /**
     * @brief Get the spectrum of integrated squared input values for the given polarization
     *
     * Integration is performed across time samples.
     * For undetected inputs this spectrum is the integrated bandpass.
     */
    const double* get_variance (unsigned ipol) const;

    //! Get the number of samples between updates
    uint64_t get_nsample () const;

    //! Get the total power time series for the given polarization
    const float* get_time (unsigned ipol) const;

    class Engine;
    class ScaleOffsetCalculator;

    //! Callback to use when scales and offsets have been computed
    Callback<Rescale*> scales_updated;

    /**
     * @brief set the Engine used to perform the rescale transformation.
     *
     * If no Engine is set then the Rescale class will use the default implementation on the CPU.
     *
     * @param _engine a pointer to the Engine to use to perform the rescale transformation.
     */
    void set_engine(Engine * _engine);

    /**
     * @brief set the ScaleOffsetCalculator to use when calculating the scales and offsets
     *
     * The default calculator used is the RescaleMeanStdCalculator that calculates the mean and standard deviation.
     *
     * @param _calculator a pointer to the ScaleOffsetCalculator to use when calculating the scales and offsets
     */
    void set_calculator(ScaleOffsetCalculator* _calculator) { calculator = _calculator; }

    /**
     * @brief get the ScaleOffsetCalculator used in calculating the scales and offsets.
     *
     * @returns a read-only pointer to the calculator currently referenced by this object.
     */
    const ScaleOffsetCalculator* get_calculator() const { return calculator; }

  private:

    //! normalisation decay offsets, ordered by [pol][chan]
    std::vector< std::vector<float> > decay_offset;

    //! flag that ensures an exact number of samples, interval_samples, is used to compute offsets and scales
    bool exact{false};

    //! flag that enables the calculation of the time_total vector
    bool output_time_total{false};

    //! flag to prevent output of any data before the first integration interval has passed [not used]
    bool output_after_interval{false};

    //! length of data to use, in seconds, when computing the offsets and scales. Ignored if interval_samples > 0
    double interval_seconds{0.0};

    //! length of data to use, in samples, when computing the offsets and scales
    uint64_t interval_samples{0};

    //! decay constant to use if do_decay is true
    float decay_constant{1e4};

    //! flag to control the use of the decay constant
    bool do_decay{false};

    //! the number of time samples the transformation will assess when computing statistics
    uint64_t nsample{0};

    //! counter for the number of processed time samples
    uint64_t isample{0};

    //! epoch of the last scale/offset update (presumably the value reported by get_update_epoch)
    MJD update_epoch{};

    //! flag that holds the offset and scale constant, after the first calculation
    bool constant_offset_scale{false};

    //! flag to track whether the first integration has been computed
    bool first_integration{false};

    //! allocate and initialise attributes for the transformation.
    void init ();

    //! compute the offset and scale factors from the freq_total and freq_totalsq vectors
    // NOTE(review): the parameter first_integration has the same name as the data member
    // declared above; if the definition uses this name, the member is hidden inside the body.
    void compute_various (bool first_call = false, bool first_integration = false);

    //! Interface to alternate processing engine (e.g. GPU)
    Reference::To<Engine> engine;

    //! Reference to a scales and offset calculator
    Reference::To<ScaleOffsetCalculator> calculator;
  };

  /**
   * @brief abstract interface to an alternate processing engine (e.g. GPU) for the Rescale transformation.
   *
   * Every method is pure virtual; a concrete engine is handed to Rescale via Rescale::set_engine.
   */
  class Rescale::Engine : public Reference::Able
  {
  public:
    /**
     * @brief initialise the engine.
     *
     * @param input a pointer to the input timeseries
     * @param nsample the number of samples used to calculate the scales and offsets.
     * @param exact an indicator of whether to have exact number of samples to calculate the scales and offsets.
     * @param constant_offset_scale an indicator of whether to only calculate the scale and offset once or each time
     *    the transform method is called.
     */
    virtual void init(const dsp::TimeSeries *input, uint64_t nsample, bool exact, bool constant_offset_scale) = 0;

    /**
     * @brief perform the transformation of the input timeseries data and write out to the given output timeseries
     *
     * @param input a pointer to the input timeseries data.
     * @param output a pointer to the output timeseries data.
     */
    virtual void transform(const dsp::TimeSeries *input, dsp::TimeSeries *output) = 0;

    /**
     * @brief get the spectrum of offsets for the given polarization
     *
     * @param ipol the polarisation index to get spectrum of offsets for.
     *
     * @returns the spectrum of offsets for the given polarization
     */
    virtual const float *get_offset(unsigned ipol) const = 0;

    /**
     * @brief get the spectrum of scales for the given polarization
     *
     * @param ipol the polarisation index to get spectrum of scales for.
     *
     * @returns the spectrum of scales for the given polarization
     */
    virtual const float *get_scale(unsigned ipol) const = 0;

    /**
     * @brief Get the spectrum of integrated input values for the given polarization
     *
     * Integration is performed across nsample time samples.
     *
     * @param ipol the polarisation index to get spectrum of integrated inputs values for.
     * @returns the spectrum of integrated input values for the given polarization.
     */
    virtual const double *get_freq_total(unsigned ipol) const = 0;

    /**
     * @brief Get the spectrum of integrated squared input values for the given polarization
     *
     * Integration is performed across the nsample time samples.
     *
     * @param ipol the polarisation index to get spectrum of the total input values squared for the current nsample.
     * @returns the spectrum of integrated squared input values for the given polarization.
     */
    virtual const double *get_freq_squared_total(unsigned ipol) const = 0;
  };

  /**
   * @brief an abstract class that is used to calculate the scales and offsets of the data
   *
   * Sub-classes must implement init, sample_data, compute, reset_sample_data, get_mean and get_variance.
   */
  class Rescale::ScaleOffsetCalculator : public Reference::Able {
    public:
      //! Default constructor
      ScaleOffsetCalculator() = default;

      //! Default destructor
      virtual ~ScaleOffsetCalculator() = default;

      /**
       * @brief initialise the calculator to allow for allocation of arrays and/or scratch space
       *
       * @param input the input timeseries of the Rescale operation, needed to get correct nchan, npol and ndim
       * @param ndat the number of samples over which the scales and offsets are to be calculated
       * @param output_time_total flag that instructs the Calculator to compute an F-scrunch time series of rescaled data
       */
      virtual void init(const dsp::TimeSeries* input, uint64_t ndat, bool output_time_total) = 0;

      /**
       * @brief used to sample the current input timeseries to allow calculation of scales and offsets
       *
       * Implementations of this should not assume the total number of samples (end_dat - start_dat) is equal
       * to the ndat in the init method.  This is due to how the Rescale class works and there might be a few
       * iterations of sampling data before the total ndat is achieved.
       *
       * @param input the input timeseries to sample the data for, it may be either FPT or TFP ordered data.
       * @param start_dat the start time sample, may not be 0
       * @param end_dat the end time sample; not necessarily input->get_ndat(), it may be less than that.
       * @param output_time_total an indicator of whether to accumulate the time sample values.
       * @returns the ending time sample (should be end_dat)
       */
      virtual uint64_t sample_data(const dsp::TimeSeries* input, uint64_t start_dat, uint64_t end_dat, bool output_time_total) = 0;

      /**
       * @brief calculate the scale and offset values based on the sampled data
       *
       * @param nsample the number of data samples to use when calculating the statistics, this will either
       *  be the ndat passed in during the call to the init method or the size of ndat from the input timeseries.
       */
      virtual void compute(uint64_t nsample) = 0;

      /**
       * @brief reset the accumulated sample data to allow recalculation of statistics later.
       */
      virtual void reset_sample_data() = 0;

      /**
       * @brief get a constant reference to the calculated scales, ordered by [pol][chan]
       */
      const std::vector<std::vector<float>>& get_scales() const { return scale; }

      /**
       * @brief get a constant reference to the calculated offsets, ordered by [pol][chan]
       */
      const std::vector<std::vector<float>>& get_offsets() const { return offset; }

      /**
       * @brief get a pointer to the scales for the given polarisation
       */
      virtual const float* get_scale (unsigned ipol) const;

      /**
       * @brief get a pointer to the offsets for the given polarisation
       */
      virtual const float* get_offset (unsigned ipol) const;

      /**
       * @brief get a pointer to the accumulated time samples for the given polarisation
       */
      virtual const float* get_time (unsigned ipol) const;

      /**
       * @brief get a pointer to the mean values for the channels of the given polarisation
       */
      virtual const double* get_mean (unsigned ipol) const = 0;

      /**
       * @brief get a pointer to the variance values for the channels of the given polarisation
       */
      virtual const double* get_variance (unsigned ipol) const = 0;

    protected:
      //! The number of samples to use to calculate the scales and offset over
      uint64_t ndat{0};

      //! The number of frequency channels to calculate the scales and offsets for
      unsigned nchan{0};

      //! The number of polarisations, per channel, to calculate the scales and offsets for
      unsigned npol{0};

      //! The number of dimensions per sample value
      unsigned ndim{0};

      //! normalisation scale, ordered by [pol][chan]
      std::vector< std::vector<float>> scale;

      //! normalisation offset, ordered by [pol][chan]
      std::vector< std::vector<float>> offset;

      //! normalisation decay offsets, ordered by [pol][chan]
      // NOTE(review): Rescale declares a member with the same name and description; confirm which one is used
      std::vector< std::vector<float>> decay_offset;

      //! accumulated time samples, one vector per polarisation
      // NOTE(review): presumably the F-scrunched time series exposed by get_time; confirm against the implementation
      std::vector< std::vector<float>> time_total;
  };

} // namespace dsp

#endif // __dsp_Rescale_h
